This notebook details a quick analysis we performed on the ESM2 (Evolutionary Scale Modeling) model for masked token prediction. We stumbled upon a counter-intuitive result related to the effect that masking one residue has on the distribution of another.
Before jumping into our results, let’s first establish what masked token prediction is, how it works, and our consequent motivation for this analysis.
Brief primer
At its core, masked token prediction is a fundamental training technique in natural language processing. The basic idea is simple: take a sequence of text, hide some of its tokens (by replacing them with a special “mask” token), and ask the model to predict what should go in those masked positions. This fill-in-the-blanks approach, formally known as masked language modeling (MLM), is remarkably effective at enabling models to learn underlying patterns and dependencies in sequential data1. And in the protein language model space, nearly every state-of-the-art model is trained with masked token prediction.
1 The BERT paper (Devlin et al., 2019) is the seminal work that popularized MLM.
Once a model is trained on this task, you can mutate a sequence of interest by masking residues and having the model predict what amino acid the mask should be replaced with. This is called unmasking and the procedure goes like this. First, an incomplete sequence is passed through the model. It’s incomplete because at one or more residues, a masked token is placed in lieu of an amino acid. The collection of masked positions is denoted as \(M\) and the partially masked sequence is denoted as \(S_{-M}\). Here’s an example:
\[
S_{-M} = \text{A R <mask> A D I <mask>}
\]
Here, \(M = \lbrace 2, 6 \rbrace\), which is \(0\)-indexed because Python has corrupted our minds. When this sequence is passed through a model with an attached language-modelling head (e.g. an ESM2 model), the output is an array of unnormalized probabilities for each amino acid at the masked positions. We call these values logits, which can be readily converted into amino acid probabilities for each masked residue2. The amino acid probability distribution of the \(i^\text{th}\) residue is denoted as \(p(i|M)\), where \(i \in M\).
2 What are logits? Inside a protein language model, each layer produces neuronal activations as it processes the sequence. These intermediate activations represent increasingly abstract features of the amino acid patterns. The final layer produces specific activations dubbed logits: one logit for each possible amino acid that could appear at a masked position. A high logit value for an amino acid corresponds to a high probability of that amino acid. This blog has more details.
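To make the logit-to-probability conversion concrete, here's a minimal sketch (plain NumPy, with made-up logit values) of how a softmax turns unnormalized logits into a probability distribution:

```python
import numpy as np

def softmax(logits: np.ndarray) -> np.ndarray:
    # Subtracting the max improves numerical stability without changing the result.
    exps = np.exp(logits - logits.max())
    return exps / exps.sum()

# Hypothetical logits for three candidate amino acids at a masked position.
logits = np.array([2.0, 1.0, 0.1])
probs = softmax(logits)

# Higher logits yield higher probabilities, and the probabilities sum to one.
assert probs[0] > probs[1] > probs[2]
assert np.isclose(probs.sum(), 1.0)
```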
Motivation
With this understanding of masked token prediction in place, we can explore an important tradeoff in how these predictions are made. When performing masked token prediction, there is a balance between efficiency and information utilization. On one extreme, you could take \(S_{-M}\), pass it through the model once, and then unmask each masked token. This would be very computationally efficient. However, this approach precludes the model from leveraging the information gained from each newly unmasked position to inform subsequent predictions. At the other end of the spectrum, you could unmask tokens one at a time: predict the first position, update the sequence with that prediction, pass the updated sequence through the model again to predict the second position, and so on. This iterative approach allows for maximum information utilization but can be really expensive in terms of computational resources. There is also a concern that by unmasking individually, you only make conservative choices that prevent a deeper exploration of sequence space.
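The two extremes can be sketched in a few lines. This is an illustration only: `predict_best` is a hypothetical stand-in for a model forward pass (here it just picks a random amino acid), but the control flow mirrors the one-shot and iterative strategies described above:

```python
import random

AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
MASK = "<mask>"

def predict_best(sequence: list[str], position: int) -> str:
    # Hypothetical stand-in for a model forward pass; a real implementation
    # would return the most probable amino acid at the masked position.
    return random.choice(AMINO_ACIDS)

def unmask_one_shot(sequence: list[str]) -> list[str]:
    # One forward pass: every masked position is filled independently.
    return [
        predict_best(sequence, i) if token == MASK else token
        for i, token in enumerate(sequence)
    ]

def unmask_iterative(sequence: list[str]) -> list[str]:
    # One forward pass per mask: each prediction is written back into the
    # sequence before the next masked position is predicted.
    sequence = list(sequence)
    for i, token in enumerate(sequence):
        if token == MASK:
            sequence[i] = predict_best(sequence, i)
    return sequence

masked = ["A", "R", MASK, "A", "D", "I", MASK]
assert MASK not in unmask_one_shot(masked)
assert MASK not in unmask_iterative(masked)
```

In practice the iterative loop also needs an ordering heuristic (e.g. unmask the most confident position first), which is exactly the kind of strategy choice this analysis aims to inform.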
All this is to say that developing effective unmasking strategies is an important concern for maximizing the utility of masked-token prediction models. And to do that, it’s worth gaining some insight into how masking at one position affects the outcome at another masked position. Our analysis approaches this question by masking two locations, both individually and jointly, and quantifying any changes in their predicted amino acid probabilities. Specifically, we’ll consider a pair of positions \(i\) and \(j\) and ask:
How does \(p(i | \lbrace i \rbrace)\) compare with \(p(i | \lbrace i, j \rbrace)\)?
How does \(p(j | \lbrace j \rbrace)\) compare with \(p(j | \lbrace i, j \rbrace)\)?
By doing this repeatedly for each pair of residues in the sequence, will we start to capture interdependencies between the sequence positions? Do these statistical dependencies relate back to things we expect, like the attention map of the model, or perhaps the contact map of the amino acids in 3D space? Are there any interpretable patterns at all? We will assess the similarity of these distributions with Jensen-Shannon divergence to find out.
Loading the model
Our analysis will focus on the ESM2 model series (Lin et al., 2023).
A good place to start is a quick survey of the available ESM2 model sizes. The ESM2 architecture was created and trained for 6 different model sizes, which are described here:
from analysis.utils import ModelName
from rich.console import Console

console = Console()

for model_name in ModelName:
    console.print(model_name, style="strong")
All of these models are hosted on HuggingFace and can be loaded using HuggingFace’s transformers Python package.
The larger models need GPUs to run. We’ll do that later. For now, let’s load up the smallest model for prototyping the analysis.
Note
You can still follow along on your own computer without GPU access.
from transformers import AutoTokenizer, EsmForMaskedLM

small_model = EsmForMaskedLM.from_pretrained(ModelName.ESM2_8M.value)
console.print(small_model)
In addition to the base ESM architecture, this model has an attached language-model decoder head, which you can see at the bottom of the printed model architecture. It is this part of the model that converts the latent embedding produced by the model into the logits array which we will then transform into amino acid probability distributions.
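To make the tensor shapes concrete, here's a toy stand-in for a decoder head (a sketch only; ESM2's actual head also includes a dense layer, a GELU activation, and layer normalization): a linear projection from the embedding dimension (320 for ESM2-8M) into the 33-token vocabulary.

```python
import numpy as np

hidden_dim = 320  # embedding dimension of ESM2-8M
vocab_size = 33   # size of ESM2's token vocabulary

# Hypothetical latent embeddings: a batch of 1 sequence of 7 tokens
# (5 residues plus the <cls> and <eos> tokens).
embeddings = np.random.randn(1, 7, hidden_dim)

# A toy decoder head: a single linear projection into vocabulary space.
weight = np.random.randn(hidden_dim, vocab_size) * 0.02
bias = np.zeros(vocab_size)
logits = embeddings @ weight + bias

# One logit per vocabulary token, at every sequence position.
assert logits.shape == (1, 7, vocab_size)
```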
Getting the probability array \(P\)
Table of all 33 tokens in the model’s vocabulary and their indices in the logits array.

Index  Token
0      <cls>
1      <pad>
2      <eos>
3      <unk>
4      L
5      A
6      G
7      V
8      S
9      E
10     R
11     T
12     I
13     D
14     P
15     K
16     Q
17     N
18     F
19     Y
20     M
21     H
22     W
23     C
24     X
25     B
26     U
27     Z
28     O
29     .
30     -
31     <null_1>
32     <mask>
A forward pass of small_model yields a logits array, which we need to transform into our desired array of amino acid probability distributions (\(P\)). The transformation requires two key steps. Since we only want amino acid probabilities, we first subset the logits to include just the 20 amino acids from the model’s 33-token vocabulary (shown in the table on the right). And second, we need to apply a softmax to convert these selected logits into probabilities. Here’s the code to perform both steps:
import pandas as pd
import torch

from analysis.utils import amino_acids

tokenizer = AutoTokenizer.from_pretrained(
    ModelName.ESM2_8M.value,
    clean_up_tokenization_spaces=True,
)
aa_token_indices = tokenizer.convert_tokens_to_ids(amino_acids)

def get_probs_array(logits):
    """Given a logits array, return an amino acid probability distribution array.

    Args:
        logits:
            The output from the language-modelling head of an ESM2 model.
            Shape is (batch_size, seq_len, vocab_size).

    Returns:
        A probability array. Shape is (batch_size, seq_len, 20).
    """
    if not isinstance(logits, torch.Tensor):
        logits = torch.tensor(logits)

    # Only include tokens that correspond to amino acids. This excludes
    # special, non-amino acid tokens.
    logits_subset = logits[..., aa_token_indices]

    return torch.nn.functional.softmax(logits_subset, dim=-1).detach().numpy()
As a sanity check, let’s create a sequence of repeating leucines. Then we’ll mask one of the residues and see what the model predicts at that position. Although there is no theoretically “correct” answer for what the model should predict, given that the entire sequence is homogeneously leucine, it would be bizarre if the model predicted anything but leucine.
import numpy as np
import pandas as pd

# Create a sequence of all leucine, then tokenize it.
sequence = "L" * 200
tokenized = tokenizer([sequence], return_tensors="pt")

# Place a mask token in the middle of the sequence. Note that residue `mask_pos`
# sits at index `mask_pos + 1` in the tokenized sequence because of a preceding
# "CLS" token that prefixes the tokenized sequence.
mask_pos = 100
tokenized["input_ids"][0, mask_pos + 1] = tokenizer.mask_token_id

# Pass the output of the model through our `get_probs_array` function.
logits_array = small_model(**tokenized).logits  # Shape (batch_size=1, seq_len=202, vocab_size=33)
probs_array = get_probs_array(logits_array)  # Shape (batch_size=1, seq_len=202, vocab_size=20)

# We are only interested in the position we masked, so let's subset to that.
probs_array_at_mask = probs_array[0, mask_pos + 1, :]

# Like a probability distribution should, the elements sum to one (or very nearly).
assert np.isclose(probs_array_at_mask.sum().item(), 1.0)

probs = dict(zip(amino_acids, probs_array_at_mask, strict=False))
probs = dict(
    sorted(((aa, prob.item()) for aa, prob in probs.items()), key=lambda x: x[1], reverse=True)
)
pd.Series(probs).to_frame(name="Probability")
The probability masses predicted by the model for the masked residue.

Token  Probability
L      0.990615
P      0.001851
I      0.001022
S      0.000988
T      0.000935
R      0.000911
A      0.000728
V      0.000715
F      0.000417
Q      0.000303
H      0.000272
Y      0.000250
D      0.000236
G      0.000217
M      0.000147
K      0.000143
E      0.000115
N      0.000060
W      0.000057
C      0.000018
Encouragingly, the elements sum to one and the model confidently (>99%) predicts leucine.
Sequence of interest
Now let’s introduce a real protein to work with. Any protein will do, really, but we’ve been exploring the use of adenosine deaminase (human) as a template for protein design, so we’ve decided to use that.
We now know how to turn the model output, a logits array, into an array of amino acid probability distributions. This marks the launch point for our comparison of amino acid probability distributions for single- versus double-residue masking. We’ll do this exhaustively for our sequence of interest by establishing two different libraries: the single masks, of which there are \(N\), and the double masks, of which there are \(N \choose 2\).
from itertools import combinations

N = len(sequence)
single_masks = list(combinations(range(N), 1))
double_masks = list(combinations(range(N), 2))

console.print(f"{single_masks[:5]=}")
console.print(f"{double_masks[:5]=}")

assert len(single_masks) == N
assert len(double_masks) == N * (N - 1) / 2
This is the compute-heavy portion of the analysis. We’re using Modal’s cloud infrastructure to distribute our computations across GPUs of varying power (from T4s to H100s). For each model in our analysis, we predict logits arrays twice: once for the single mask library, and again for the double mask library. The computation becomes increasingly demanding with larger models - ESM2-8M costs just pennies to run, while ESM2-15B costs around $30 due to its need for more powerful GPUs and smaller batch sizes. All of this behavior is encapsulated in the calculate_or_load_logits function imported below.
Note
If you’re following along on your own computer, don’t worry - the logits are tracked in the GitHub repository, so the function call below will simply load pre-computed logits from disk without requiring Modal or GPU access.
We’ve already written a method to convert logits to amino acid probabilities, so let’s go ahead and run that conversion for all the data.
single_probs: dict[ModelName, np.ndarray] = {}
double_probs: dict[ModelName, np.ndarray] = {}

for model in ModelName:
    single_probs[model] = get_probs_array(single_logits[model])
    double_probs[model] = get_probs_array(double_logits[model])
Calculating probability metrics for each residue pair
With that, we finally have amino acid probability distributions for (a) each individually masked residue and (b) every pairwise combination of doubly-masked residues.
As a reminder, we’re interested in comparing \(p(i | \lbrace i \rbrace)\) to \(p( i | \lbrace i, j \rbrace)\) and \(p(j | \lbrace j \rbrace)\) to \(p( j | \lbrace i, j \rbrace)\). To compare the degree of similarity between probability distributions, we’re going to use Jensen-Shannon (JS) divergence3. We’ll use base-2 so that \(0\) corresponds to identical distributions and \(1\) corresponds to maximally diverged distributions.
3 We originally planned to use Kullback-Leibler divergence, but because it’s not symmetric (i.e. \(KL(x,y) \ne KL(y, x)\)), it’s not, formally speaking, a distance measure. JS divergence was designed specifically to resolve this.
Since we’ll have to loop through each residue pair to calculate the JS divergence, it will be useful to calculate some other metrics along the way:
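The metric computation itself isn't shown here, but as an illustration, a base-2 JS divergence between two amino acid distributions can be sketched like this (our own NumPy version, not necessarily the helper used in the analysis):

```python
import numpy as np

def js_divergence(p: np.ndarray, q: np.ndarray) -> float:
    """Base-2 Jensen-Shannon divergence: 0 for identical distributions and
    1 for distributions with no overlapping support."""
    m = 0.5 * (p + q)

    def kl(a: np.ndarray, b: np.ndarray) -> float:
        # Terms with a == 0 contribute nothing, so skip them to avoid log(0).
        nonzero = a > 0
        return float(np.sum(a[nonzero] * np.log2(a[nonzero] / b[nonzero])))

    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

p = np.full(20, 1 / 20)  # uniform over the 20 amino acids
q = np.zeros(20)
q[0] = 1.0               # all mass on a single amino acid

assert np.isclose(js_divergence(p, p), 0.0)  # identical distributions
assert 0.0 < js_divergence(p, q) < 1.0       # partially overlapping distributions
```

Note that `scipy.spatial.distance.jensenshannon` returns the square root of this quantity (the JS distance), so the two should not be compared directly without squaring.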
pairwise_data is a dictionary of dataframes, one for each model. Each dataframe holds the following metrics for each residue pair.
Calculated metrics:

position_i: The \(i^{\text{th}}\) residue position.
position_j: The \(j^{\text{th}}\) residue position.
amino_acid_i: The original amino acid at the \(i^{\text{th}}\) residue.
most_probable_i_i: The amino acid with the highest probability in \(p(i \vert \lbrace i \rbrace)\).
most_probable_i_ij: The amino acid with the highest probability in \(p( i \vert \lbrace i, j \rbrace)\).
amino_acid_j: The original amino acid at the \(j^{\text{th}}\) residue.
most_probable_j_j: The amino acid with the highest probability in \(p(j \vert \lbrace j \rbrace)\).
most_probable_j_ij: The amino acid with the highest probability in \(p( j \vert \lbrace i, j \rbrace)\).
perplex_i_i: The perplexity4 of \(p(i \vert \lbrace i \rbrace)\).
perplex_i_ij: The perplexity of \(p( i \vert \lbrace i, j \rbrace)\).
perplex_j_j: The perplexity of \(p(j \vert \lbrace j \rbrace)\).
perplex_j_ij: The perplexity of \(p( j \vert \lbrace i, j \rbrace)\).
js_div_i: The Jensen-Shannon divergence between \(p(i \vert \lbrace i \rbrace)\) and \(p( i \vert \lbrace i, j \rbrace)\).
js_div_j: The Jensen-Shannon divergence between \(p(j \vert \lbrace j \rbrace)\) and \(p( j \vert \lbrace i, j \rbrace)\).
js_div_avg: (js_div_i + js_div_j) / 2
4 What is perplexity? The perplexity of a probability distribution measures how uncertain or “surprised” the model is about its predictions. For amino acid predictions, a perplexity of 20 indicates maximum uncertainty (equal probability for all 20 amino acids), while a perplexity of 1 indicates complete certainty (100% probability for a single amino acid).
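The two extremes in this footnote can be verified directly. A quick sketch (our own helper; perplexity here is \(2^{H(p)}\) with the entropy \(H\) measured in bits):

```python
import numpy as np

def perplexity(p: np.ndarray) -> float:
    # Perplexity is 2 ** H(p), where H is the Shannon entropy in bits.
    nonzero = p > 0  # skip zero-probability entries to avoid log(0)
    entropy = -float(np.sum(p[nonzero] * np.log2(p[nonzero])))
    return 2.0 ** entropy

uniform = np.full(20, 1 / 20)  # maximum uncertainty over 20 amino acids
certain = np.zeros(20)
certain[0] = 1.0               # 100% probability on a single amino acid

assert np.isclose(perplexity(uniform), 20.0)
assert np.isclose(perplexity(certain), 1.0)
```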
Also note that the perplexities in this table refer to single residues, but in the field, (pseudo-)perplexity is often used to score a model’s confidence in an entire sequence (e.g. the ESM2 paper (Lin et al., 2023)) via methods developed in Salazar et al. (2020).
Let’s inspect some of the data for the largest model.
Before we compare single- to double-masking results, let’s see what the reconstruction accuracy is for the protein. Relatedly, we can look at the average site perplexity to get an indication of the model’s confidence.
Code
import arcadia_pycolor as apc
import matplotlib.pyplot as plt

apc.mpl.setup()

def get_recon_acc(df: pd.DataFrame) -> float:
    matches = df["most_probable_i_i"] == df["amino_acid_i"]
    return 100 * matches.sum() / len(matches)

def avg_perplexity(df: pd.DataFrame) -> float:
    return df.perplex_i_i.mean()

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5.25))

ax1.plot(
    [model.name for model in ModelName],
    [get_recon_acc(pairwise_data[model]) for model in ModelName],
    marker="o",
    linestyle="-",
    color="#1f77b4",
)
ax1.set_xlabel("Model")
ax1.set_ylabel("Reconstruction Accuracy (%)")
ax1.tick_params(axis="x", rotation=20)
ax1.grid(True)

ax2.plot(
    [model.name for model in ModelName],
    [avg_perplexity(pairwise_data[model]) for model in ModelName],
    marker="o",
    linestyle="-",
    color="#1f77b4",
)
ax2.set_xlabel("Model")
ax2.set_ylabel("<Perplexity>")
ax2.tick_params(axis="x", rotation=20)
ax2.grid(True)

plt.tight_layout()
plt.show()
Figure 1: Single masked token reconstruction accuracy and confidence for all ESM2 model sizes. (left) The fraction of positions where, when masked, the model predicts the original amino acid as most likely. (right) The perplexity of \(p(i \vert \lbrace i \rbrace)\) averaged over sites.
The effects of single- versus double-masking
In the single-masking case, the model has access to all context except the masked token \(i\). However, in double-masking, the model loses access to both tokens \(i\) and \(j\), reducing the available context for making predictions. This reduction in context should lead to greater uncertainty in the model’s predictions. We can quantify this uncertainty using perplexity. Our hypothesis is that \(p(i | \lbrace i,j \rbrace)\) will show higher perplexity than \(p(i | \lbrace i \rbrace)\) because (1) the model has less contextual information to work with and (2) the presence of multiple masks creates more ambiguity about the relationships between the masked tokens.
Figure 2: Histograms of the difference in perplexity between \(p(i | \lbrace i, j \rbrace)\) and \(p(i | \lbrace i \rbrace)\) for all residue pairs.
While the results show a slight average shift toward higher perplexity when double-masking, the effect is small. These data suggest that, on average, the predicted amino acid probabilities at one masked position are nearly independent of whether a second position is also masked, especially for the larger models.
However, these averages might mask important position-specific dependencies. Even if masking residue \(j\) typically has little effect on predicting residue \(i\), there could be specific pairs of positions where masking \(j\) substantially impacts the prediction of residue \(i\). To identify the extent of these position pairs, we can examine a more targeted metric: for each position \(i\), what is the maximum increase in perplexity caused by masking any other position \(j\)? This analysis will help reveal whether certain amino acid positions have strong dependencies that are hidden when looking at average effects.
Code
import matplotlib.pyplot as plt
import numpy as np

fig, axes = plt.subplots(
    1, len(ModelName), figsize=(5 * len(ModelName), 4.3), sharey=True, sharex=True
)

for ax, model in zip(axes, ModelName, strict=True):
    df = pairwise_data[model]
    df["diff"] = df["perplex_i_ij"] - df["perplex_i_i"]
    diff = df.groupby("position_i")["diff"].max()

    ax.grid(True)
    ax.hist(diff, bins=np.linspace(0, 16, 50), edgecolor="black")
    ax.set_title(f"{model.name}\nAverage max shift: +{diff.mean():.3f}")
    ax.set_xlabel(r"$ \max_j(PP_{i| \lbrace i, j \rbrace} - PP_{i| \lbrace i \rbrace}) $")
    ax.set_yscale("log")

    if ax == axes[0]:
        ax.set_ylabel("Count")

plt.tight_layout()
plt.show()
Figure 3: Histograms of the max difference in perplexity between \(p(i | \lbrace i, j \rbrace)\) and \(p(i | \lbrace i \rbrace)\). Compared to Figure 2, only the position \(j\) that causes the largest perplexity increase in \(p(i | \lbrace i, j \rbrace)\) for a given \(i\) is included.
These position-specific maximum shifts show that while random position pairs on average share little dependence, for any given position \(i\) there tends to exist at least one position \(j\) that noticeably impacts the model’s confidence in predicting \(i\). Even in the 15B parameter model, which has an average perplexity of less than 2, double-masking can significantly perturb perplexity. To understand the magnitude of this effect in practical terms, let’s see whether \(p(i | \lbrace i \rbrace)\) and \(p(i | \lbrace i,j \rbrace)\) ever predict different amino acids as most probable, as that would indicate positions where context truly alters the model’s understanding of the protein sequence, rather than just lowering its confidence.
Code
import matplotlib.pyplot as plt
import numpy as np

aa_changes = []
for model in ModelName:
    df = pairwise_data[model]
    # True where the top prediction differs between single- and double-masking.
    df["mismatch"] = df["most_probable_i_i"] != df["most_probable_i_ij"]
    aa_change = df.groupby("position_i")["mismatch"].any().sum()
    aa_changes.append(aa_change)

plt.figure(figsize=(11, 6))
plt.bar(
    [model.name for model in ModelName],
    aa_changes,
)
plt.title("Positions where double-masking changes the predicted amino acid")
plt.xlabel("Model")
plt.ylabel(f"Number of positions affected\n(total # residues = {len(sequence)})")
plt.show()
Figure 4: The number of positions where the most probable amino acid differs between the distributions \(p(i | \lbrace i \rbrace)\) and \(p(i | \lbrace i,j \rbrace)\).
In the largest model, more than 10% of positions have at least one partner position whose masking changes which amino acid is predicted as most likely. While this reveals clear cases of strong positional dependence, it only captures the most extreme effects - cases where context actually shifts the model’s top prediction. To capture more subtle interactions, where masking position \(j\) meaningfully shifts the probability distribution at position \(i\) without necessarily changing the most likely amino acid, we finally arrive at Jensen-Shannon (JS) divergence.
Visualizing Jensen-Shannon divergence
The relationships between all pairs of positions can be effectively visualized using heatmaps, where each cell \((i,j)\) shows the JS divergence between \(p(i|\lbrace i \rbrace)\) and \(p(i| \lbrace i,j \rbrace)\).
With that in mind, below is a series of heatmaps for the average Jensen-Shannon divergence, with increasing model size.
Code
import analysis.plotting as plotting

def get_js_div_matrix(df: pd.DataFrame) -> np.ndarray:
    indices = sorted(set(df["position_i"]).union(df["position_j"]))
    matrix = pd.DataFrame(index=indices, columns=indices)  # type: ignore

    for _, row in df.iterrows():
        matrix.at[row.position_i, row.position_j] = row.js_div_avg
        matrix.at[row.position_j, row.position_i] = row.js_div_avg

    np.fill_diagonal(matrix.values, np.nan)
    matrix_values = matrix.to_numpy()
    return matrix_values.astype(np.float32)

for model in ModelName:
    size = 700
    fig = plotting.visualize_js_div_matrix(get_js_div_matrix(pairwise_data[model]))
    fig.update_layout(
        width=size,
        height=size * 80 / 100,
    )
    fig.show()
Figure 5: Heatmap of average JS divergence for each model size. Each cell is colored based on the average of \(JS(p(i \vert \lbrace i \rbrace), p(i \vert \lbrace i,j \rbrace))\) and \(JS(p(j \vert \lbrace j \rbrace), p(j \vert \lbrace i,j \rbrace))\).
(a) The ESM2_8M parameter model.
(b) The ESM2_35M parameter model.
(c) The ESM2_150M parameter model.
(d) The ESM2_650M parameter model.
(e) The ESM2_3B parameter model.
(f) The ESM2_15B parameter model.
Comparison to 3D contact map
A striking pattern emerges and, interestingly, it is most prominent in the intermediate-sized 150M parameter model. It looks an awful lot like a structural contact map, so let’s load up the structure, calculate the pairwise residue distances, and compare.
We’ll plot the JS-divergence for the 150M model on the upper left triangular region of the heatmap and the contact map on the lower right triangular region.
Code
import biotite.structure as struc
from biotite.structure.io import load_structure

# Use alpha carbons
array = load_structure("input/P00813.pdb")
array = array[array.atom_name == "CA"]

distance_map = np.zeros((len(array), len(array)))
for i in range(len(array)):
    distance_map[i, :] = struc.distance(array[i], array)

# Avoid a divide-by-zero on the diagonal (each residue is at distance 0 from itself).
np.fill_diagonal(distance_map, np.inf)

# Define the contact map as the inverse of the distances
contact_map = 1 / distance_map

m150_divergence = get_js_div_matrix(pairwise_data[ModelName.ESM2_150M])

fig = plotting.compare_to_contact_map(m150_divergence, contact_map)
fig.show()
Figure 6: Comparison between JS divergence and contact map. (upper-left) The heatmap values in Figure 5 for the 150M parameter model. (bottom-right) The contact map. Calculated by taking the inverse distance between alpha-carbon atoms for each pair of residues, and is expressed in units of \(Å^{-1}\).
This illustrates that some pairwise epistasis is being captured: masking one residue affects the probability distribution of a second masked residue, and the strength of this effect inversely correlates with the physical distance between the residues. Given that the attention maps learned in the ESM models recapitulate protein contact maps, seeing this propagate to the logits isn’t too surprising. However, unlike the attention maps, which recapitulate contact maps better and better with increasing model size, the correlation with the contact map here is strongest at intermediate model sizes.
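One way to move beyond visual inspection (this isn't part of the original analysis) is a rank correlation between the off-diagonal entries of the two matrices, applied to the `m150_divergence` and `contact_map` arrays computed above. A sketch with a self-contained toy check:

```python
import numpy as np
from scipy.stats import spearmanr

def offdiag_correlation(divergence: np.ndarray, contact: np.ndarray) -> float:
    """Spearman rank correlation between the upper-triangle (off-diagonal)
    entries of a JS-divergence matrix and a contact map."""
    rows, cols = np.triu_indices(divergence.shape[0], k=1)
    x = divergence[rows, cols]
    y = contact[rows, cols]
    finite = np.isfinite(x) & np.isfinite(y)  # drop NaN/inf entries
    return float(spearmanr(x[finite], y[finite]).correlation)

# Toy check: a symmetric matrix correlates perfectly with itself.
rng = np.random.default_rng(0)
m = rng.random((10, 10))
m = (m + m.T) / 2
assert np.isclose(offdiag_correlation(m, m), 1.0)
```

Running `offdiag_correlation` for each model size would give a single number per model to track how the contact-map signal rises and falls with scale.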
To be sure, here is the same plot for the largest, most capable ESM2 model.
Figure 7: Comparison between JS divergence and contact map. (upper-left) The heatmap values in Figure 5 for the 15B parameter model. (bottom-right) The contact map. Calculated by taking the inverse distance between alpha-carbon atoms for each pair of residues, and is expressed in units of \(Å^{-1}\).
Conclusion
The contrast between these two models is quite interesting. The 15B parameter model has a 90% single-mask reconstruction accuracy, whereas for the 150M parameter model the reconstruction accuracy is only 50%. Given this, we expected that the 150M parameter model would perhaps struggle to create residue pair dependencies that are biologically interpretable and that perhaps the 15B parameter model might provide more interpretable answers. However, the above results show the opposite.
We’re having trouble interpreting this result and we’re seeking your opinion. So our question to you is: Why do you think the contact map pattern diminishes in the JS-divergence for the larger models?
References
Devlin J, Chang M-W, Lee K, Toutanova K. (2019). BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding. https://doi.org/10.48550/arXiv.1810.04805
Lin Z, Akin H, Rao R, Hie B, Zhu Z, Lu W, Smetanin N, Verkuil R, Kabeli O, Shmueli Y, Santos Costa A dos, Fazel-Zarandi M, Sercu T, Candido S, Rives A. (2023). Evolutionary-scale prediction of atomic-level protein structure with a language model. https://doi.org/10.1126/science.ade2574